#!/usr/bin/env python  
# -*- coding: utf-8 -*-
""" 
@author:hello_life 
@license: Apache License
@file: albert_config.py 
@time: 2022/04/21
@software: PyCharm 
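description: Training configuration (data/model paths, hyperparameters, device) for an ALBERT text classifier on the IMDB dataset.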
"""
import os

import torch

class Config:
    """Training configuration for the IMDB classifier.

    Holds file-system paths, training hyperparameters and the compute device.
    Every path is assembled from separate components with ``os.path.join`` so
    the correct separator is used on every platform (the previous
    ``"data\\IMDB_train.csv"``-style literals only worked on Windows).

    Args:
        dir_path: Base directory containing the ``from_pretrained``, ``data``
            and ``save_model`` sub-directories. Defaults to the current working
            directory (the original behavior). Converted to an absolute path.
    """

    def __init__(self, dir_path=None):
        if dir_path is None:
            dir_path = os.getcwd()
        dir_path = os.path.abspath(dir_path)

        # Pre-trained model directory. The trailing "" component preserves the
        # trailing path separator that the original "from_pretrained\\" had.
        self.model_path = os.path.join(dir_path, "from_pretrained", "")
        # Training / test / validation data files.
        self.train_path = os.path.join(dir_path, "data", "IMDB_train.csv")
        self.test_path = os.path.join(dir_path, "data", "IMDB_test.csv")
        self.val_path = os.path.join(dir_path, "data", "IMDB_val.csv")
        self.batch_size = 8
        self.max_length = 10
        # File the trained model is saved to.
        self.save_model = os.path.join(dir_path, "save_model", "save_model.pt")

        # Number of label classes.
        self.num_class = 2
        # Number of training epochs (attribute name "epoches" kept for
        # compatibility with existing callers).
        self.epoches = 1

        # Learning rate.
        self.lr = 1e-5

        # "cuda" when a CUDA device is available, otherwise "cpu".
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

if __name__ == '__main__':
    # Script entry point: report which device a default configuration selects.
    print(Config().device)